import rasterio
import numpy as np
import pandas as pd
import geopandas as gpd
from shapely.geometry import Point
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import argparse

parser = argparse.ArgumentParser(
    description='Compute and visualize population-weighted regions.'
)
parser.add_argument(
    '--raster',
    default='data/mexico/mx-query-setup/pop-grid/mx_pd_2020_1km_Aggregated.tif',
    help='Path to population raster (GeoTIFF)',
)
parser.add_argument(
    '--shapefile',
    default='data/mexico/mx-query-setup/priogrid-shape/priogrid_mex.shp',
    help='Path to regions shapefile',
)
parser.add_argument(
    '--out-img',
    default='output/mexico/mx-query-setup/priogrid_mex_wt_pop_2020.jpg',
    help='Output image file',
)
parser.add_argument(
    '--out-csv',
    default='output/mexico/mx-query-setup/priogrid_mex_wt_pop_2020.csv',
    help='Output CSV summary file',
)

# Parse the command line only when this file is executed as a script.
# Parsing at import time would consume the importer's sys.argv (e.g. a test
# runner's own flags) and exit with a usage error, so the functions below
# could not be imported. When imported, ``args`` is None; use
# ``parser.parse_args([...])`` explicitly if parsed options are needed.
args = parser.parse_args() if __name__ == '__main__' else None

def load_raster(path):
    """Open the raster at *path* with rasterio and return the open dataset.

    The dataset is returned still open; the caller is responsible for
    closing it when finished.
    """
    print(f"Loading raster from {path}...")
    dataset = rasterio.open(path)
    crs, bounds = dataset.crs, dataset.bounds
    print(f"Raster loaded: CRS={crs}, bounds={bounds}")
    return dataset


def load_shapefile(path):
    """Read the vector file at *path* with geopandas into a GeoDataFrame."""
    print(f"Loading shapefile from {path}...")
    frame = gpd.read_file(path)
    n_features, crs = len(frame), frame.crs
    print(f"Shapefile loaded: {n_features} features, CRS={crs}")
    return frame


def raster_to_points(raster):
    """Convert valid raster cells to a GeoDataFrame of cell-centre points.

    Parameters
    ----------
    raster : rasterio dataset
        Open dataset. Band 1 supplies the values, and the band-1 mask from
        ``read_masks(1)`` decides which cells are kept: cells whose mask
        value is 0 are skipped.

    Returns
    -------
    geopandas.GeoDataFrame
        One row per kept cell, with columns ``population``, ``x`` and ``y``
        and a point ``geometry`` at the cell centre, in the raster's CRS.
    """
    print("Converting raster cells to point features...")
    band1 = raster.read(1)
    mask = raster.read_masks(1)
    rows, cols = np.where(mask != 0)
    # The affine transform maps (col, row) to the cell's upper-left corner.
    # Offset by half a pixel so each point lies at the cell centre. Corner
    # points would shift every cell half a pixel and assign cells near a
    # region boundary to the wrong region in the later point-in-polygon join.
    xs, ys = raster.transform * (cols + 0.5, rows + 0.5)
    values = band1[rows, cols]

    df = pd.DataFrame({
        'population': values,
        'x': xs,
        'y': ys,
    })
    gdf = gpd.GeoDataFrame(
        df,
        geometry=gpd.points_from_xy(df.x, df.y),
        crs=raster.crs,
    )
    print(f"Generated {len(gdf)} point features from raster")
    return gdf


def reproject_to_match(gdf, target_gdf):
    """Return *gdf* reprojected (via ``to_crs``) into the CRS of *target_gdf*."""
    target_crs = target_gdf.crs
    print(f"Reprojecting data from CRS {gdf.crs} to match target CRS {target_crs}...")
    result = gdf.to_crs(target_crs)
    print("Reprojection complete")
    return result


def spatial_join(pop_pts, regions):
    """Attach region attributes to each population point lying within a region.

    This is an inner join with the ``within`` predicate, so points that fall
    inside no region are dropped from the result.
    NOTE(review): a point inside two overlapping regions would appear once
    per region; confirm the regions do not overlap.
    """
    print("Performing spatial join (points within regions)...")
    result = gpd.sjoin(
        left_df=pop_pts,
        right_df=regions,
        how="inner",
        predicate="within",
    )
    print(f"Spatial join complete: {len(result)} joined records")
    return result


def summarize_population(joined):
    """Sum the ``population`` column per ``gid``.

    Returns a DataFrame with columns ``gid`` and ``population`` and a fresh
    0..n-1 index. Only gids that occur in *joined* appear in the result.
    """
    print("Summarizing population by region...")
    per_region = joined.groupby('gid', as_index=False)['population'].sum()
    print(f"Summary complete: {len(per_region)} regions with population data")
    return per_region


def attach_population(regions, summary):
    """Return a copy of *regions* with a ``population`` column added.

    Values are looked up from *summary* by ``gid``. Regions with no match
    (or a missing value) get 0. The input *regions* frame is not modified.
    """
    print("Attaching population summary to regions GeoDataFrame...")
    pop_by_gid = summary.set_index('gid')['population']
    result = regions.copy()
    matched = result['gid'].map(pop_by_gid)
    result['population'] = matched.fillna(0)
    print("Attachment complete")
    return result


def plot_population(regions, output_path, log_scale=True):
    """Plot regions coloured by population and save the figure to a file.

    Parameters
    ----------
    regions : geopandas.GeoDataFrame
        Must contain a numeric ``population`` column.
    output_path : str or path-like
        Destination image file. Missing parent directories are created.
    log_scale : bool
        Use a logarithmic colour scale bounded by the smallest and largest
        positive populations. If no region has a positive population, a log
        scale is undefined, so the plot falls back to a linear scale.
    """
    from pathlib import Path

    print(f"Plotting population map to {output_path} (log scale={log_scale})...")
    population = regions['population']
    positive = population[population > 0]
    if log_scale and positive.empty:
        print("No positive population values; falling back to linear scale")
        log_scale = False

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        plot_kwargs = {'column': 'population', 'ax': ax, 'legend': True}
        if log_scale:
            # NOTE(review): LogNorm masks values <= 0, so zero-population
            # regions are likely drawn in the colormap's "bad" colour
            # (transparent by default). Confirm this rendering is intended.
            plot_kwargs['norm'] = LogNorm(vmin=positive.min(), vmax=positive.max())
        regions.plot(**plot_kwargs)
        ax.set_axis_off()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print("Map saved successfully")


def save_summary_csv(summary, output_path):
    """Write the population summary DataFrame to CSV, without the index.

    Parameters
    ----------
    summary : pandas.DataFrame
        Table to write.
    output_path : str or path-like
        Destination CSV file. Missing parent directories are created, so
        default output locations work on a fresh checkout.
    """
    from pathlib import Path

    print(f"Saving population summary to CSV at {output_path}...")
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    summary.to_csv(output_path, index=False)
    print("CSV saved successfully")


def main(
    raster_path: str,
    shapefile_path: str,
    output_image: str,
    output_csv: str
) -> None:
    """Run the geospatial population weighting workflow.

    The workflow has five steps:

    1. Load the population raster and the regions file.
    2. Convert valid raster cells to points.
    3. Reproject the regions into the raster's CRS.
    4. Sum point populations per region ``gid``.
    5. Write a map image and a CSV summary.

    Parameters
    ----------
    raster_path : str
        Population raster (GeoTIFF).
    shapefile_path : str
        Regions vector file; must have a ``gid`` column.
    output_image : str
        Path of the map image to write.
    output_csv : str
        Path of the per-region CSV summary to write.
    """
    print("--- Starting population weighting workflow ---")

    # Load data. The raster dataset is only needed to extract the point
    # features, so it is scoped with ``with`` and closed as soon as that is
    # done, including when an error occurs.
    with load_raster(raster_path) as raster:
        regions = load_shapefile(shapefile_path)

        # Convert raster to point features
        pop_pts = raster_to_points(raster)

    # Reproject regions to match points
    regions = reproject_to_match(regions, pop_pts)

    # Spatial join and summarize
    joined = spatial_join(pop_pts, regions)
    summary = summarize_population(joined)

    # Attach back to regions and plot
    regions_with_pop = attach_population(regions, summary)
    plot_population(regions_with_pop, output_image)

    # Save summary
    save_summary_csv(summary, output_csv)

    print("--- Workflow complete ---")


if __name__ == '__main__':
    # Script entry point: forward the parsed command-line options to main().
    main(
        raster_path=args.raster,
        shapefile_path=args.shapefile,
        output_image=args.out_img,
        output_csv=args.out_csv,
    )
